% Reset the session and load the estimation sample. The MAT-file "Data" must
% provide a variable named Data, which is used below as a table (dot access,
% row indexing).
clc;
clearvars;
load('Data');
format('short', 'g');


%=========== Define simulation parameters ===========%
max_iter    = 100;   % number of bootstrap marriage markets (parfor iterations)
market_size = 100;   % households drawn into each market
cluster     = 1;     % non-zero: cluster run with MOSEK; 0: local run with Gurobi
wbound      = 0;     % wage-bound half-width, in percent of the type's wage std (see Table 25)

% Char versions of the parameters, used to build the results file name.
[text_cluster, text_size, text_iter, text_wbound] = deal( ...
    num2str(cluster), num2str(market_size), num2str(max_iter), num2str(wbound));

%=========== Define variables to store results ===========%
% One row (or cell) per parfor iteration.
shareM = NaN(max_iter, 9);         % male bounds: [min max] per category, then iteration id
shareF = NaN(max_iter, 9);         % female bounds, same layout
DivorceCosts = NaN(max_iter, 12);  % 9 averaged cost columns + 3 stability shares
HHIndices = cell(max_iter, 1);     % familyid of the households drawn in each market
StabilityIndices = cell(max_iter, 1);


% Solver name handed to get_indices/get_ricebs, and the parfor worker cap.
if cluster ~= 0
    text_solver = 'mosek';
    parnum = 20;
    addpath(genpath("/opt/apps/easybuild/software/YALMIP/R20230609/"))
    addpath(genpath("~/mosek/10.1/toolbox/r2017a"))
else
    text_solver = 'gurobi';
    parnum = 2;
end

%=========== Define chores = housework + childcare hours ===========%
% Chore hours per spouse. The male value is zeroed in single-female households
% and the female value in single-male households.
Data.chorehm = Data.hworkm + Data.chcarem;
Data.chorehm(Data.singlefemale == 1) = 0;
Data.chorehf = Data.hworkf + Data.chcaref;
Data.chorehf(Data.singlemale == 1) = 0;


%=========== Define povertyline ===========%
% Household consumption is q + Q. Couples' consumption is halved to get a
% per-capita value, and singles keep the full amount. The poverty line is 60% of
% the median per-capita consumption.
Data.consumption = Data.q + Data.Q;
Data.percapitaconsumption(Data.couple == 1) = Data.consumption(Data.couple == 1) / 2;
Data.percapitaconsumption(Data.couple == 0) = Data.consumption(Data.couple == 0);
povertyline = 0.6 * median(Data.percapitaconsumption);


%=========== Define age restrictions for marriage market ===========%
% Husband-minus-wife age gaps among observed couples. The central 95% range
% [lb, ub] is later passed to get_consideration_sets.
age_diff = Data.agem(Data.couple == 1) - Data.agef(Data.couple == 1);
lb = prctile(age_diff, 2.5);
ub = prctile(age_diff, 97.5);


%===========  Define age, education and children categories for preference types ===========%
% Category codes per household row (0 = missing, or outside every bracket):
%   agecatm/agecatf : 1 = age in (1,35], 2 = (35,50], 3 = over 50
%   educatm/educatf : education code as recorded, 0 when missing
%   kidscat         : 1 = no children, 2 = at least one child
% The preference type packs these into one integer:
%   type = educat*100 + kidscat*10 + agecat
% The type of the spouse absent from a single household is set to 0.
% Everything is computed column-wise, for two reasons:
%   - element-by-element assignment into table variables is very slow;
%   - every row gets an explicit code, so rows that match no bracket
%     (age <= 1, negative child count) no longer depend on the fill value of a
%     table variable grown element by element.
age_to_cat = @(age) 1*(age > 1 & age <= 35) + 2*(age > 35 & age <= 50) + 3*(age > 50);
Data.agecatm = age_to_cat(Data.agem);   % NaN ages fail every comparison -> 0
Data.agecatf = age_to_cat(Data.agef);
clear age_to_cat

Data.educatm = Data.edum;
Data.educatm(isnan(Data.educatm)) = 0;
Data.educatf = Data.eduf;
Data.educatf(isnan(Data.educatf)) = 0;

Data.kidscat = 1*(Data.nchild == 0) + 2*(Data.nchild > 0);   % NaN -> 0

Data.typem = Data.educatm*100 + Data.kidscat*10 + Data.agecatm;
Data.typef = Data.educatf*100 + Data.kidscat*10 + Data.agecatf;

% Zero the absent spouse's type. If both single flags are set, only typef is
% zeroed (the single-male flag takes precedence).
Data.typef(Data.singlemale == 1) = 0;
Data.typem(Data.singlefemale == 1 & Data.singlemale ~= 1) = 0;


%=========== Define household types and weights ===========%
% Household type = male type * 1000 + female type. value_counts has one row per
% distinct household type: [type code, number of households, empirical share].
Data.types = Data.typem .* 1000 + Data.typef;
[C, ~, ic] = unique(Data.types);
a_counts = accumarray(ic, 1);
value_counts = [C, a_counts, a_counts ./ sum(a_counts)];


%=========== Define bounds on wages ===========%
% One row per non-zero preference type, [mean wage, std of wage, type code],
% ignoring missing wages. Men go into wagem_bound, women into wagef_bound.
wage_row = @(w, t, k) [mean(w(t == k), "omitnan"), std(w(t == k), "omitnan"), k];

wagem_bound = [];
x = unique(Data.typem);
for i = 1:length(x)
    if x(i) == 0
        continue   % type 0 marks an absent spouse
    end
    wagem_bound(end+1, :) = wage_row(Data.wagem, Data.typem, x(i)); %#ok<SAGROW>
end

wagef_bound = [];
x = unique(Data.typef);
for i = 1:length(x)
    if x(i) == 0
        continue   % type 0 marks an absent spouse
    end
    wagef_bound(end+1, :) = wage_row(Data.wagef, Data.typef, x(i)); %#ok<SAGROW>
end
clear wage_row


%=========== Define couple type based on education and employment status of spouses ===========%
% For couples, each spouse's category is 10*(1 + employment flag) + education
% category, and hcat = 100*hcatm + hcatf. Single households get hcat = 0.
% hcatm/hcatf are only written for couple rows.
for i = 1:size(Data,1)
    if Data.couple(i) ~= 1
        if Data.singlemale(i) == 1 || Data.singlefemale(i) == 1
            Data.hcat(i) = 0;
        end
        continue
    end
    Data.hcatm(i) = 10*(1+Data.male_emp(i)) + Data.educatm(i);
    Data.hcatf(i) = 10*(1+Data.female_emp(i)) + Data.educatf(i);
    Data.hcat(i)  = 100*Data.hcatm(i) + Data.hcatf(i);
end





%%
%=========== Estimate the model with subsamples ===========%
% Each parfor iteration builds one synthetic marriage market of market_size
% households and redraws until both solver stages succeed. Setting the
% Threefry stream's Substream to iter makes each iteration's draws reproducible
% independently of which worker runs it.
sc = parallel.pool.Constant(RandStream('Threefry'));


parfor (iter = 1:max_iter, parnum)

    stream = sc.Value;
    stream.Substream = iter;

    % NOTE(review): this retry loop is unbounded; if the solvers never succeed
    % for some iteration, the parfor never finishes.
    complete = false;
    while ~complete

        %=========== Randomly draw household types from their weighted distribution ===========%
        % Sample row indices of value_counts, then map them to type codes.
        % Passing the codes themselves as the population breaks when there is a
        % single type, because randsample reads a scalar population as 1:n.
        % This form consumes the stream identically, so draws are unchanged otherwise.
        type_rows = randsample(stream, size(value_counts,1), market_size, true, value_counts(:,3));
        drawn_types = value_counts(type_rows, 1);
        [type_codes, ~, type_ic] = unique(drawn_types);
        index_counts = [type_codes, accumarray(type_ic, 1)];   % [type code, households needed]

        %=========== Given the number of household types required, randomly draw households of that type from the observed households ===========%
        draws = cell(size(index_counts,1), 1);
        for i = 1:size(index_counts,1)
            Extract = Data(Data.types == index_counts(i,1), :);
            Extract_index = randsample(stream, size(Extract,1), index_counts(i,2), true);
            draws{i} = Extract(Extract_index, :);
        end
        Data_r = vertcat(draws{:});

        %=========== Define consideration set ===========%
        consideration_set = get_consideration_sets(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,Data_r.agem,Data_r.agef,lb,ub);

        %=========== Compute divorce costs (stability indices) ===========%
        fprintf('Computing Divorce Costs...\n');
        [temp_DivorceCostsM,temp_DivorceCostsF,temp_DivorceCostsMF,solved] = get_indices(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
            Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
            Data_r.nonlabor,consideration_set,Data_r.typem,Data_r.typef,wagem_bound,wagef_bound,wbound/100,text_solver);
        if solved ~= 0
            continue   % first stage not solved: draw a new market
        end

        %=========== Compute RICEBs ===========%
        fprintf('Computing RICEBs...\n');
        [temp_min_male,temp_max_male,temp_min_female,temp_max_female, problem] = get_ricebs(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
            Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
            Data_r.nonlabor,consideration_set,temp_DivorceCostsM,temp_DivorceCostsF,temp_DivorceCostsMF,Data_r.hcat,Data_r.typem,Data_r.typef,wagem_bound,wagef_bound,wbound/100,text_solver);

        % Accept the draw only if every bound output has at least one non-NaN
        % entry and the problem flag is 0.
        accepted = any(~isnan(temp_min_male(:))) && any(~isnan(temp_max_male(:))) ...
            && any(~isnan(temp_min_female(:))) && any(~isnan(temp_max_female(:))) ...
            && problem == 0;
        if ~accepted
            continue
        end

        %=========== Store stability indices ===========%
        % Couple rows are selected once. "~= 0" selects the same rows as find().
        is_couple = Data_r.couple ~= 0;
        ncouple = nnz(is_couple);
        hh_consumption = Data_r.consumption(is_couple);

        % Pairwise costs outside the consideration set are ignored. Presumably
        % rows index men and columns index women (dim-2 reductions are labelled
        % "m") — confirm against get_indices.
        temp_DivorceCostsMF(~consideration_set) = NaN;
        DC_temp = temp_DivorceCostsMF;

        DivorceCosts_maxm = max(DC_temp, [], 2);
        DivorceCosts_maxf = max(DC_temp)';
        DivorceCosts_avgm = sum(DC_temp, 2, "omitnan") ./ sum(~isnan(DC_temp), 2);
        DivorceCosts_avgf = sum(DC_temp, "omitnan")' ./ sum(~isnan(DC_temp))';
        cset_size_m = sum(consideration_set, 2, "omitnan");
        cset_size_f = sum(consideration_set, "omitnan")';

        % One row per couple. Costs are expressed relative to household consumption.
        DivorceCosts_temp = zeros(ncouple, 9);
        DivorceCosts_temp(:,1) = temp_DivorceCostsM(is_couple) ./ hh_consumption;
        DivorceCosts_temp(:,2) = temp_DivorceCostsF(is_couple) ./ hh_consumption;
        DivorceCosts_temp(:,3) = DivorceCosts_maxm(is_couple) ./ hh_consumption;
        DivorceCosts_temp(:,4) = DivorceCosts_maxf(is_couple) ./ hh_consumption;
        DivorceCosts_temp(:,5) = DivorceCosts_avgm(is_couple) ./ hh_consumption;
        DivorceCosts_temp(:,6) = DivorceCosts_avgf(is_couple) ./ hh_consumption;
        DivorceCosts_temp(:,7) = iter;
        DivorceCosts_temp(:,8) = cset_size_m(is_couple);
        DivorceCosts_temp(:,9) = cset_size_f(is_couple);

        zero_tol = 0.00001;   % relative cost at or below which a couple counts as stable
        NBPStable = DivorceCosts_temp(:,5) <= zero_tol & DivorceCosts_temp(:,6) <= zero_tol;
        IRMStable = DivorceCosts_temp(:,1) <= zero_tol;
        IRFStable = DivorceCosts_temp(:,2) <= zero_tol;
        % Average down the couples (dim 1 explicitly). With a single couple, the
        % default dimension would collapse the 1x9 row to a scalar.
        DivorceCosts(iter,:) = [mean(DivorceCosts_temp, 1), mean(NBPStable), mean(IRMStable), mean(IRFStable)];

        % Pairwise matrix, bordered by the individual costs (singles set to NaN):
        % male costs as an extra column, female costs (plus a 0 corner) as an
        % extra row.
        temp_DivorceCostsM(~is_couple) = NaN;
        temp_DivorceCostsF(~is_couple) = NaN;
        DC_temp = [DC_temp, temp_DivorceCostsM];
        DC_temp = [DC_temp; [temp_DivorceCostsF', 0]];
        StabilityIndices{iter} = DC_temp;
        HHIndices{iter} = Data_r.familyid;

        %=========== Store RICEBs ===========%
        % Interleave the bounds as [min1 max1 min2 max2 ... iter]. This assumes
        % column-vector outputs whose length fits the preallocated width of
        % shareM/shareF — confirm in get_ricebs.
        shareM(iter,:) = [reshape([temp_min_male(:), temp_max_male(:)].', 1, []), iter];
        shareF(iter,:) = [reshape([temp_min_female(:), temp_max_female(:)].', 1, []), iter];

        complete = true;
        fprintf('*********** Iteration %d finished *************\n', iter);
    end
end

% Save every workspace variable (inputs, parameters and per-iteration results)
% under a name that encodes the run parameters.
% NOTE(review): this also saves the parallel.pool.Constant `sc`; consider excluding it.
filename = sprintf('DataPSID_results_ricebs_4cat_size%s_iter%s_wbound%s_Cluster%s', ...
                   text_size, text_iter, text_wbound, text_cluster);
save(filename)



%===================== TABLE 25 =========================%
% Wage interval per type: mean -/+ (wbound/100) * std. Men first, then women.
for bound_cell = {wagem_bound, wagef_bound}
    wb = bound_cell{1};
    display([wb(:,1)-0.01*wbound.*wb(:,2), wb(:,1)+0.01*wbound.*wb(:,2)])
end


%===================== TABLE 26 =========================%
% Mean (over iterations, ignoring NaN) of the first six DivorceCosts columns, in percent.
table26_labels = ["IR male", "IR female", "NBP max male", "NBP max female", ...
                  "NBP avg male", "NBP avg female"];
for col = 1:numel(table26_labels)
    fprintf("%s: %.2f \n", table26_labels(col), mean(DivorceCosts(:,col),"omitnan")*100);
end

%===================== TABLE 27 =========================%
% Mean lower/upper bounds per category (shareM/shareF column pairs 2k-1, 2k).
% The category order is assumed to match get_ricebs' output order — confirm.
table27_groups = ["employed = no and education = low", "employed = no and education = high", ...
                  "employed = yes and education = low", "employed = yes and education = high"];
for k = 1:numel(table27_groups)
    fprintf('Bounds for male %s = [%.4f, %.4f] \n', table27_groups(k), ...
        mean(shareM(:,2*k-1), "omitnan"), mean(shareM(:,2*k), "omitnan"));
end

for k = 1:numel(table27_groups)
    fprintf('Bounds for female %s = [%.4f, %.4f] \n', table27_groups(k), ...
        mean(shareF(:,2*k-1), "omitnan"), mean(shareF(:,2*k), "omitnan"));
end

